package com.shujia.mr;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Partitioner;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

import java.io.IOException;

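/**
 * Word count demo that mitigates data skew caused by a single hot key.
 * The hot key ("hadoop") is "salted" with a random suffix in the Mapper so its
 * records spread over several reducers instead of piling up on one, and a
 * custom Partitioner routes each salted variant to its own reducer.
 * Note that the output then contains partial counts (hadoop_0/1/2) that a
 * second aggregation pass would have to merge into a final count for "hadoop".
 */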
public class Demo05SkewDataMR {
    public static class MyMapper extends Mapper<LongWritable, Text, Text, IntWritable> {
        @Override
        protected void map(LongWritable key, Text value, Mapper<LongWritable, Text, Text, IntWritable>.Context context) throws IOException, InterruptedException {
            String line = value.toString();
            // Split each line on commas/whitespace ("+" avoids counting empty
            // tokens when delimiters appear back to back)
            for (String word : line.split("[,\\s]+")) {
                // Use context.write to emit each word downstream as (word, 1)
                // For the skewed key, append a random suffix (a "salt")
                if ("hadoop".equals(word)) {
                    // Randomly generates 0, 1, or 2
                    int suffix = (int) (Math.random() * 3);
                    context.write(new Text(word + "_" + suffix), new IntWritable(1));
                } else {
                    context.write(new Text(word), new IntWritable(1));
                }
            }
        }
    }
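
    // The Mapper above, for example, turns the input line "hadoop,spark hadoop"
    // into (hadoop_2, 1), (spark, 1), (hadoop_0, 1),
    // where the numeric suffixes are chosen at random per record.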


    public static class MyReducer extends Reducer<Text, IntWritable, Text, IntWritable> {
        @Override
        protected void reduce(Text key, Iterable<IntWritable> values, Reducer<Text, IntWritable, Text, IntWritable>.Context context) throws IOException, InterruptedException {
            // Sum the counts for each word
            int cnt = 0;
            for (IntWritable value : values) {
                cnt = cnt + value.get();
            }
            context.write(key, new IntWritable(cnt));
        }
    }
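
    // The salting leaves three partial counts (hadoop_0/1/2) in the output.
    // Below is a minimal sketch of a second-stage Mapper that merges them back,
    // assuming it is wired into a separate follow-up job that reuses MyReducer;
    // the class is illustrative and not used by this demo's driver.
    public static class MergeMapper extends Mapper<LongWritable, Text, Text, IntWritable> {
        @Override
        protected void map(LongWritable key, Text value, Mapper<LongWritable, Text, Text, IntWritable>.Context context) throws IOException, InterruptedException {
            // The first job's TextOutputFormat writes "word<TAB>count" per line
            String[] kv = value.toString().split("\t");
            if (kv.length != 2) {
                return;
            }
            // Strip the random salt so hadoop_0/1/2 all collapse back to "hadoop"
            String word = kv[0].startsWith("hadoop_") ? "hadoop" : kv[0];
            context.write(new Text(word), new IntWritable(Integer.parseInt(kv[1])));
        }
    }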

    // Driver: assembles (schedules) and configures the job.
    // Parameters can be received via args;
    // this job takes two: the input path and the output path.
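    // Example invocation (jar name and paths are illustrative, not from this repo):
    //   hadoop jar demo.jar com.shujia.mr.Demo05SkewDataMR /input/words /output/wc_skew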
    public static void main(String[] args) throws IOException, InterruptedException, ClassNotFoundException {
        Configuration conf = new Configuration();
        // Create the Job
        Job job = Job.getInstance(conf);

        // Configure the job
        job.setJobName("Demo05SkewDataMR");
        job.setJarByClass(Demo05SkewDataMR.class);

        // Set the custom partitioner
        job.setPartitionerClass(SkewPartitioner.class);

        // Manually set the number of Reduce tasks.
        // The number of files written to HDFS equals the number of Reduce tasks,
        // and it must be at least the number of salt values (3 here),
        // or SkewPartitioner may return an out-of-range partition index.
        job.setNumReduceTasks(3);

        // Configure the Map side
        job.setMapperClass(MyMapper.class);
        job.setMapOutputKeyClass(Text.class);
        job.setMapOutputValueClass(IntWritable.class);

        // Configure the Reduce side
        job.setReducerClass(MyReducer.class);
        job.setOutputKeyClass(Text.class);
        job.setOutputValueClass(IntWritable.class);

        // Validate the number of arguments
        if (args.length != 2) {
            System.out.println("Please provide the input and output directories!");
            return;
        }

        String input = args[0];
        String output = args[1];

        // Configure the input and output paths
        FileInputFormat.addInputPath(job, new Path(input));

        Path outputPath = new Path(output);
        // Use FileSystem to implement overwrite semantics:
        // delete the output directory first if it already exists
        FileSystem fs = FileSystem.get(conf);
        if (fs.exists(outputPath)) {
            fs.delete(outputPath, true);
        }
        // FileOutputFormat requires that the output directory not exist
        // (it is created automatically); if it already exists the job fails
        // immediately, which is why it is deleted above
        FileOutputFormat.setOutputPath(job, outputPath);

        // Submit the job and wait for it to complete;
        // exit with a non-zero status if the job fails
        System.exit(job.waitForCompletion(true) ? 0 : 1);
    }
}

// Custom partitioner: the Map stage appended a random suffix to the skewed key;
// route records to different partitions based on that suffix
class SkewPartitioner extends Partitioner<Text, IntWritable> {

    @Override
    public int getPartition(Text text, IntWritable intWritable, int numPartitions) {
        String key = text.toString();
        int partitions = 0;
        // Only the skewed key gets special handling:
        // pin each salted variant to its own reducer
        if ("hadoop".equals(key.split("_")[0])) {
            switch (key) {
                // "hadoop_0" needs no case: partitions already defaults to 0
                case "hadoop_1":
                    partitions = 1;
                    break;
                case "hadoop_2":
                    partitions = 2;
                    break;
            }
        } else {
            // Normal keys still use the default hash-modulo partitioning
            // (the same scheme as Hadoop's HashPartitioner)
            partitions = (key.hashCode() & Integer.MAX_VALUE) % numPartitions;
        }
        return partitions;
    }
}
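
// Design note: the salt only spreads the hot key; it does not isolate it.
// Ordinary keys may still hash into partitions 0-2, so those reducers handle
// a mix of salted "hadoop" records and regular words.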
